import os
import re
import json
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, List, Dict, Tuple

# Default font size for any text that is not given an explicit size below.
plt.rcParams['font.size'] = 14

# Explicit font sizes used by the plotting functions in this file.
TITLE_FONTSIZE = 28
SUBPLOT_TITLE_FONTSIZE = 24
AXIS_LABEL_FONTSIZE = 28
TICK_LABEL_FONTSIZE = 28
LEGEND_FONTSIZE = 20

# Labels shown in plot titles/legends and embedded in output file names.
PM_EXPERIMENT_NAME_LABEL = "PMxPM"
DPM_EXPERIMENT_NAME_LABEL = "DPMxDPM"

# Keys used to build the input directory and CSV names.
# NOTE(review): the experiment labelled "DPMxDPM" is read from the "DAxDA"
# files; confirm this mapping is intentional.
PM_EXPERIMENT_FILENAME_KEY = "PMxPM"
DA_EXPERIMENT_FILENAME_KEY = "DAxDA"

# Input locations: <KEY>_visualizations_combined_denoised/<KEY>_all_runs_processed_metrics.csv
INPUT_PM_DATA_DIR = Path(PM_EXPERIMENT_FILENAME_KEY + "_visualizations_combined_denoised")
INPUT_PM_CSV_BASENAME = PM_EXPERIMENT_FILENAME_KEY + "_all_runs_processed_metrics.csv"
INPUT_PROCESSED_METRICS_FILE_PM = INPUT_PM_DATA_DIR.joinpath(INPUT_PM_CSV_BASENAME)

INPUT_DA_DATA_DIR = Path(DA_EXPERIMENT_FILENAME_KEY + "_visualizations_combined_denoised")
INPUT_DA_CSV_BASENAME = DA_EXPERIMENT_FILENAME_KEY + "_all_runs_processed_metrics.csv"
INPUT_PROCESSED_METRICS_FILE_DA = INPUT_DA_DATA_DIR.joinpath(INPUT_DA_CSV_BASENAME)

# Output directory for every figure; created as soon as this module is loaded
# so the plotting functions can save into it without further checks.
OUTPUT_COMBINED_VIS_DIR = Path(
    "Combined_Experiments_Plots_SEM_{}_vs_{}".format(
        PM_EXPERIMENT_NAME_LABEL, DPM_EXPERIMENT_NAME_LABEL
    )
)
OUTPUT_COMBINED_VIS_DIR.mkdir(parents=True, exist_ok=True)

# Upper y-axis limit that the script entry point passes to the OSAI plot.
GLOBAL_OSAI_Y_MAX = 10

# Line colour and marker per experiment.
PM_PLOT_COLOR, PM_PLOT_MARKER = 'green', 'o'
DPM_PLOT_COLOR, DPM_PLOT_MARKER = 'red', 'x'

# Classification values drawn by the multi-panel plot, one panel each, in this order.
STRATEGIES_FOR_COMBINED_PLOT = [
    "Independent_Development",
    "Counter_Measure",
    "Exploitation_Attempt",
]

def _osai_mean_sem_by_round(df: pd.DataFrame) -> pd.DataFrame:
    """Return the per-meta-round mean and SEM of the pair-averaged OSAI score.

    The pair average is ``(osai_A + osai_B) / 2`` for each row. The returned
    frame has the columns ``meta_round``, ``mean_osai`` and ``sem_osai``.
    The input frame is left untouched.
    """
    # Work on a fresh two-column frame so no helper column leaks into `df`.
    per_row = df[["meta_round"]].assign(
        osai_avg_pair=(df["osai_A"] + df["osai_B"]) / 2
    )
    stats = per_row.groupby("meta_round")["osai_avg_pair"].agg(["mean", "sem"])
    return stats.rename(columns={"mean": "mean_osai", "sem": "sem_osai"}).reset_index()


def _plot_osai_series(ax, stats: pd.DataFrame, label: str, color: str,
                      marker: str, linestyle: str) -> float:
    """Draw one experiment's mean line and ±SEM band; return the band's highest upper edge.

    The returned value is NaN when no upper edge can be computed (for
    example when `stats` has no rows or every SEM is NaN).
    """
    band_upper = stats["mean_osai"] + stats["sem_osai"]
    band_lower = stats["mean_osai"] - stats["sem_osai"]
    ax.plot(stats["meta_round"], stats["mean_osai"],
            marker=marker, linestyle=linestyle, label=label, color=color)
    ax.fill_between(stats["meta_round"], band_lower, band_upper,
                    color=color, alpha=0.2)
    return band_upper.max()


def plot_combined_osai(
    df_pm: pd.DataFrame, 
    df_dpm: pd.DataFrame,
    total_runs_pm_func: int,
    total_runs_dpm_func: int,
    y_max: Optional[float] = None
):
    """Plot the mean OSAI score per meta-round for both experiments on one axis.

    For each experiment the two agents' scores are averaged per row, then the
    mean and the standard error of the mean are computed per ``meta_round``
    and drawn as a line with a shaded ±SEM band. The figure is written as
    PNG and PDF into ``OUTPUT_COMBINED_VIS_DIR``.

    Nothing is plotted (a message is printed instead) when either frame is
    empty or lacks one of the columns ``meta_round``, ``osai_A``, ``osai_B``.
    Neither input frame is modified.

    Args:
        df_pm: Metrics rows for the first experiment.
        df_dpm: Metrics rows for the second experiment.
        total_runs_pm_func: Run count shown as ``N`` in the first legend entry.
        total_runs_dpm_func: Run count shown as ``N`` in the second legend entry.
        y_max: Upper y-axis limit. When ``None`` the limit is derived from the
            plotted data so that the SEM bands fit with 5% headroom.
    """
    if df_pm.empty or df_dpm.empty:
        print("Cannot plot combined OSAI: One or both DataFrames are empty.")
        return

    # Report missing inputs the same way as the emptiness check above instead
    # of failing later with a bare KeyError.
    required_columns = ("meta_round", "osai_A", "osai_B")
    for arg_name, frame in (("df_pm", df_pm), ("df_dpm", df_dpm)):
        missing = [col for col in required_columns if col not in frame.columns]
        if missing:
            print(f"Cannot plot combined OSAI: {arg_name} is missing required column(s): {missing}")
            return

    stats_pm = _osai_mean_sem_by_round(df_pm)
    stats_dpm = _osai_mean_sem_by_round(df_dpm)

    fig, ax = plt.subplots(figsize=(12, 7))

    band_tops = [
        _plot_osai_series(
            ax, stats_pm,
            f"{PM_EXPERIMENT_NAME_LABEL} (Avg, N={total_runs_pm_func} ± SEM)",
            PM_PLOT_COLOR, PM_PLOT_MARKER, '-'),
        _plot_osai_series(
            ax, stats_dpm,
            f"{DPM_EXPERIMENT_NAME_LABEL} (Avg, N={total_runs_dpm_func} ± SEM)",
            DPM_PLOT_COLOR, DPM_PLOT_MARKER, '--'),
    ]

    ax.set_title(f"OSAI ({PM_EXPERIMENT_NAME_LABEL} vs {DPM_EXPERIMENT_NAME_LABEL})", fontsize=TITLE_FONTSIZE)
    ax.set_xlabel("Meta-Round", fontsize=AXIS_LABEL_FONTSIZE)
    ax.set_ylabel("OSAI Score (Keywords)", fontsize=AXIS_LABEL_FONTSIZE)

    # One tick per meta-round up to the larger of the two experiments' maxima.
    # An all-NaN column yields a NaN maximum, which is skipped here.
    round_maxima = [
        highest
        for highest in (df_pm["meta_round"].max(), df_dpm["meta_round"].max())
        if not pd.isna(highest)
    ]
    max_meta_round = int(max(round_maxima)) if round_maxima else 1
    ax.set_xticks(range(1, max_meta_round + 1))
    # tick_params keeps the label size even when the limits set below
    # cause the tick set to be regenerated.
    ax.tick_params(axis='both', labelsize=TICK_LABEL_FONTSIZE)

    if y_max is not None:
        y_top = y_max
    else:
        # Start from the autoscaled limit and widen it if a SEM band would
        # otherwise touch or cross the top of the axes.
        y_top = ax.get_ylim()[1]
        usable_tops = [top for top in band_tops if not pd.isna(top)]
        if usable_tops:
            y_top = max(y_top, max(usable_tops) * 1.05)

    ax.set_ylim(0, y_top)
    ax.legend(fontsize=LEGEND_FONTSIZE)
    ax.grid(True, linestyle=':', alpha=0.7)
    fig.tight_layout()

    plot_filename_base = OUTPUT_COMBINED_VIS_DIR / f"Combined_Avg_OSAI_{PM_EXPERIMENT_NAME_LABEL}_vs_{DPM_EXPERIMENT_NAME_LABEL}_with_SEM"
    fig.savefig(f"{plot_filename_base}.png")
    fig.savefig(f"{plot_filename_base}.pdf")
    print(f"Saved Combined OSAI plot with SEM to {plot_filename_base}.png (and .pdf)")
    plt.close(fig)

def _strategy_mean_sem_by_round(
    df: pd.DataFrame, strategy_name: str, experiment_label: str
) -> pd.DataFrame:
    """Return per-meta-round mean and SEM of the share of a pair choosing a strategy.

    For each row the share is 0, 0.5 or 1 depending on how many of
    ``classification_A`` / ``classification_B`` equal `strategy_name`.
    The returned frame has the columns ``meta_round``, ``mean_prop`` and
    ``sem_prop``. When a required column is missing a warning is printed and
    an empty frame with those columns is returned. `df` is not modified.
    """
    no_data = pd.DataFrame(columns=["meta_round", "mean_prop", "sem_prop"])
    if 'classification_A' not in df.columns or 'classification_B' not in df.columns:
        print(f"Warning: Classification columns not found in {experiment_label} data for strategy {strategy_name}.")
        return no_data
    if "meta_round" not in df.columns:
        print(f"Warning: 'meta_round' column not found in {experiment_label} data for strategy {strategy_name}.")
        return no_data

    chose_a = (df['classification_A'] == strategy_name).astype(int)
    chose_b = (df['classification_B'] == strategy_name).astype(int)
    # Work on a fresh two-column frame so no helper column leaks into `df`.
    per_row = df[["meta_round"]].assign(pair_avg_prop=(chose_a + chose_b) / 2.0)
    stats = per_row.groupby("meta_round")["pair_avg_prop"].agg(["mean", "sem"])
    return stats.rename(columns={"mean": "mean_prop", "sem": "sem_prop"}).reset_index()


def _plot_strategy_series(ax, stats: pd.DataFrame, label: str, color: str,
                          marker: str, linestyle: str) -> None:
    """Draw one experiment's mean line and ±SEM band on `ax`.

    With no rows in `stats`, an empty line is drawn so the legend entry
    still appears.
    """
    if stats.empty:
        ax.plot([], [], marker=marker, linestyle=linestyle, label=label, color=color)
        return
    ax.plot(stats["meta_round"], stats["mean_prop"],
            marker=marker, linestyle=linestyle, label=label, color=color)
    ax.fill_between(stats["meta_round"],
                    stats["mean_prop"] - stats["sem_prop"],
                    stats["mean_prop"] + stats["sem_prop"],
                    color=color, alpha=0.2)


def _max_meta_round_or_default(df: pd.DataFrame, default: int = 1) -> int:
    """Return the largest ``meta_round`` in `df`, or `default` if none is available."""
    if "meta_round" not in df.columns:
        return default
    highest = df["meta_round"].max()
    # An empty or all-NaN column yields NaN, which cannot be converted to int.
    return default if pd.isna(highest) else int(highest)


def plot_combined_strategic_responses_multipanel(
    df_pm: pd.DataFrame, 
    df_dpm: pd.DataFrame, 
    total_runs_pm_func: int, 
    total_runs_dpm_func: int, 
    strategies_to_plot: List[str]
):
    """Plot, one panel per strategy, how often each experiment chose that strategy.

    Each panel shows, per ``meta_round``, the mean share of the agent pair
    classified as the panel's strategy, with a shaded ±SEM band, for both
    experiments. The panels share the y-axis; the y-label and legend are
    placed on the first panel only. The figure is written as PNG and PDF into
    ``OUTPUT_COMBINED_VIS_DIR``.

    Nothing is plotted (a message is printed instead) when either frame is
    empty or `strategies_to_plot` is empty. Neither input frame is modified.

    Args:
        df_pm: Metrics rows for the first experiment.
        df_dpm: Metrics rows for the second experiment.
        total_runs_pm_func: Run count shown as ``N`` in the first legend entry.
        total_runs_dpm_func: Run count shown as ``N`` in the second legend entry.
        strategies_to_plot: Classification values to plot, one panel each.
    """
    if df_pm.empty or df_dpm.empty:
        print("Cannot plot combined strategic responses: One or both DataFrames are empty.")
        return

    num_strategies = len(strategies_to_plot)
    if num_strategies == 0:
        print("No strategies provided for multi-panel plot.")
        return

    # Slight padding so markers at exactly 0 or 1 are not clipped.
    proportion_plot_ylim = (-0.02, 1.02)

    max_meta_round = max(
        _max_meta_round_or_default(df_pm),
        _max_meta_round_or_default(df_dpm),
    )

    pm_strat_label = f"{PM_EXPERIMENT_NAME_LABEL} (Avg, N={total_runs_pm_func} ± SEM)"
    dpm_strat_label = f"{DPM_EXPERIMENT_NAME_LABEL} (Avg, N={total_runs_dpm_func} ± SEM)"

    # squeeze=False always returns a 2-D array of axes, so a single strategy
    # needs no special handling.
    fig, axes_grid = plt.subplots(
        1, num_strategies, figsize=(7 * num_strategies, 6.0), sharey=True, squeeze=False
    )
    axes = axes_grid[0]

    for i, strategy_name in enumerate(strategies_to_plot):
        ax = axes[i]

        stats_pm = _strategy_mean_sem_by_round(df_pm, strategy_name, PM_EXPERIMENT_NAME_LABEL)
        stats_dpm = _strategy_mean_sem_by_round(df_dpm, strategy_name, DPM_EXPERIMENT_NAME_LABEL)

        _plot_strategy_series(ax, stats_pm, pm_strat_label,
                              PM_PLOT_COLOR, PM_PLOT_MARKER, '-')
        _plot_strategy_series(ax, stats_dpm, dpm_strat_label,
                              DPM_PLOT_COLOR, DPM_PLOT_MARKER, '--')

        ax.set_title(strategy_name.replace("_", " "), fontsize=SUBPLOT_TITLE_FONTSIZE)
        ax.set_xlabel("Meta-Round", fontsize=AXIS_LABEL_FONTSIZE - 2)
        if i == 0:
            ax.set_ylabel("Proportion", fontsize=AXIS_LABEL_FONTSIZE - 2)
            ax.legend(loc='upper right', fontsize=LEGEND_FONTSIZE - 4, ncol=1)

        ax.set_xticks(range(1, max_meta_round + 1))
        ax.tick_params(axis='x', labelsize=TICK_LABEL_FONTSIZE - 2)
        ax.tick_params(axis='y', labelsize=TICK_LABEL_FONTSIZE - 2)
        ax.set_ylim(proportion_plot_ylim)
        ax.grid(True, linestyle=':', alpha=0.7)

    fig.tight_layout(rect=[0, 0.03, 1, 0.98])

    plot_filename_base = OUTPUT_COMBINED_VIS_DIR / f"Combined_Multipanel_Strategic_Responses_{PM_EXPERIMENT_NAME_LABEL}_vs_{DPM_EXPERIMENT_NAME_LABEL}_with_SEM"
    fig.savefig(f"{plot_filename_base}.png", bbox_inches='tight')
    fig.savefig(f"{plot_filename_base}.pdf", bbox_inches='tight')
    print(f"Saved Combined Multi-panel Strategic Responses plot with SEM to {plot_filename_base}.png (and .pdf)")
    plt.close(fig)

def _load_processed_metrics(
    csv_path: Path,
    experiment_label: str,
    source_key: Optional[str] = None,
) -> Tuple[pd.DataFrame, int]:
    """Load one experiment's processed-metrics CSV and count its unique runs.

    Args:
        csv_path: Location of the CSV file.
        experiment_label: Name used for the experiment in console messages.
        source_key: Optional file-name key mentioned in the messages when the
            data is read from files named differently from the label.

    Returns:
        ``(frame, run_count)``. The frame is empty and the count is 0 when the
        file is missing or cannot be loaded. The count is 0 when the file has
        no ``run`` column.
    """
    source_note = f" (source: {source_key})" if source_key else ""

    if not csv_path.exists():
        print(f"\nERROR: {experiment_label}{source_note} Input CSV file not found: {csv_path}")
        return pd.DataFrame(), 0

    try:
        metrics = pd.read_csv(csv_path)
        loaded_from = f"data (from {source_key} file):" if source_key else "data from:"
        print(f"Successfully loaded {experiment_label} {loaded_from} {csv_path} ({len(metrics)} rows)")
        if 'run' not in metrics.columns:
            print(f"Warning: 'run' column not found in {experiment_label} data. Total runs set to 0.")
            return metrics, 0
        metrics['run'] = metrics['run'].astype(int)
        run_count = metrics['run'].nunique()
        print(f"Total unique runs in {experiment_label} data: {run_count}")
        return metrics, run_count
    except Exception as e:
        # Best-effort boundary: report the problem and continue with no data
        # for this experiment so the other one can still be reported on.
        print(f"Error loading {experiment_label}{source_note} CSV file {csv_path}: {e}")
        return pd.DataFrame(), 0


def main() -> None:
    """Load both experiments' metrics and write the combined comparison plots."""
    print("Starting Combined Plotting Script...")
    print(f"Output directory for combined plots: {OUTPUT_COMBINED_VIS_DIR}")

    main_df_pm, total_num_runs_pm = _load_processed_metrics(
        INPUT_PROCESSED_METRICS_FILE_PM, PM_EXPERIMENT_NAME_LABEL
    )
    main_df_dpm, total_num_runs_dpm = _load_processed_metrics(
        INPUT_PROCESSED_METRICS_FILE_DA,
        DPM_EXPERIMENT_NAME_LABEL,
        source_key=DA_EXPERIMENT_FILENAME_KEY,
    )

    if main_df_pm.empty and main_df_dpm.empty:
        print(f"\nBoth {PM_EXPERIMENT_NAME_LABEL} and {DPM_EXPERIMENT_NAME_LABEL} dataframes are empty. Nothing to plot. Exiting.")
        # SystemExit is always available; the exit() helper is injected by the
        # site module and is intended for interactive sessions only.
        raise SystemExit(0)
    if main_df_pm.empty:
        print(f"\nWarning: {PM_EXPERIMENT_NAME_LABEL} dataframe is empty. Combined plots involving {PM_EXPERIMENT_NAME_LABEL} might be incomplete or fail.")
    if main_df_dpm.empty:
        print(f"\nWarning: {DPM_EXPERIMENT_NAME_LABEL} dataframe is empty. Combined plots involving {DPM_EXPERIMENT_NAME_LABEL} might be incomplete or fail.")

    print("\n--- Generating Combined Visualizations with SEM ---")

    # Both plots compare the two experiments, so both need both data sets.
    if not main_df_pm.empty and not main_df_dpm.empty:
        # Copies are passed so the loaded frames stay untouched regardless of
        # what the plotting functions do with their arguments.
        plot_combined_osai(main_df_pm.copy(), main_df_dpm.copy(),
                           total_num_runs_pm, total_num_runs_dpm,
                           y_max=GLOBAL_OSAI_Y_MAX)
        plot_combined_strategic_responses_multipanel(
            main_df_pm.copy(), main_df_dpm.copy(),
            total_num_runs_pm, total_num_runs_dpm,
            strategies_to_plot=STRATEGIES_FOR_COMBINED_PLOT
        )
    else:
        print("Skipping combined OSAI plot due to missing data for one or both experiments.")
        print("Skipping combined multi-panel strategic response plots due to missing data for one or both experiments.")

    print("\n--- Combined Visualizations Complete ---")
    print(f"Output plots are in: {OUTPUT_COMBINED_VIS_DIR}")
    print("\nScript finished.")


if __name__ == "__main__":
    main()
